{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "from tempfile import gettempdir\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torchvision.models.resnet import resnet50\n",
    "from tqdm import tqdm\n",
    "\n",
    "from l5kit.configs import load_config_data\n",
    "from l5kit.data import LocalDataManager\n",
    "from l5kit.dataset import AgentDataset, EgoDataset\n",
    "from l5kit.dataset.dataloader_builder import build_dataloader\n",
    "from l5kit.rasterization import build_rasterizer\n",
    "from l5kit.evaluation import write_coords_as_csv, compute_mse_error_csv\n",
    "from l5kit.geometry import transform_points\n",
    "from l5kit.visualization import PREDICTED_POINTS_COLOR, TARGET_POINTS_COLOR, draw_trajectory\n",
    "from prettytable import PrettyTable\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare data path and load config\n",
    "\n",
    "By setting the `L5KIT_DATA_FOLDER` variable, we can point the script to the folder where the data lies.\n",
    "\n",
    "Then, we load our config file with relative paths and other configurations (rasteriser, training params...)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point l5kit at the dataset root; replace the placeholder with the path to your local data folder.\n",
    "os.environ[\"L5KIT_DATA_FOLDER\"] = \"PATH_TO_YOUR_DATA\"\n",
    "# Load the YAML configuration; the path is relative to the current working directory.\n",
    "cfg = load_config_data(\"./agent_motion_config.yaml\")\n",
    "print(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "Our baseline is a simple `resnet50` pretrained on `imagenet`. We must replace the input layer (to accept the rasterised history channels) and the final layer (to output two coordinates per future frame) to address our requirements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model(cfg: Dict) -> torch.nn.Module:\n",
    "    \"\"\"Build the baseline trajectory-regression network.\n",
    "\n",
    "    Starts from an ImageNet-pretrained `resnet50` and replaces its first\n",
    "    convolution and its final fully connected layer so that the input channel\n",
    "    count and the output size match the values derived from `cfg`.\n",
    "\n",
    "    Args:\n",
    "        cfg: configuration dict; `model_params.history_num_frames` and\n",
    "            `model_params.future_num_frames` are read from it.\n",
    "\n",
    "    Returns:\n",
    "        The modified `resnet50` module (not yet moved to any device).\n",
    "    \"\"\"\n",
    "    # load pre-trained Conv2D model\n",
    "    model = resnet50(pretrained=True)\n",
    "\n",
    "    # change input size\n",
    "    # two channels per frame for history + current frame, plus 3 extra channels\n",
    "    # (presumably the RGB map raster - confirm against the rasterizer in use)\n",
    "    num_history_channels = (cfg[\"model_params\"][\"history_num_frames\"] + 1) * 2\n",
    "    num_in_channels = 3 + num_history_channels\n",
    "    # the replacement conv1 is a fresh layer: only its geometry is copied from the\n",
    "    # pre-trained one, not its weights\n",
    "    model.conv1 = nn.Conv2d(\n",
    "        num_in_channels,\n",
    "        model.conv1.out_channels,\n",
    "        kernel_size=model.conv1.kernel_size,\n",
    "        stride=model.conv1.stride,\n",
    "        padding=model.conv1.padding,\n",
    "        bias=False,\n",
    "    )\n",
    "    # change output size\n",
    "    # X, Y  * number of future states\n",
    "    num_targets = 2 * cfg[\"model_params\"][\"future_num_frames\"]\n",
    "    # take the feature width from the existing head instead of hard-coding 2048\n",
    "    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_targets)\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(data, model, device, criterion):\n",
    "    inputs = data[\"image\"].to(device)\n",
    "    targets = data[\"target_positions\"].to(device).reshape(len(data[\"target_positions\"]), -1)\n",
    "    # Forward pass\n",
    "    outputs = model(inputs)\n",
    "    loss = criterion(outputs, targets)\n",
    "    loss = loss.mean()\n",
    "    return loss, outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data and initialise the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No explicit folder is passed to the data manager here; the data location is expected to come from\n",
    "# the L5KIT_DATA_FOLDER environment variable set earlier (presumably read by l5kit - confirm).\n",
    "dm = LocalDataManager(None)\n",
    "# ===== INIT DATASETS\n",
    "# A single rasterizer instance is shared by the train and validation dataloaders.\n",
    "rasterizer = build_rasterizer(cfg, dm)\n",
    "train_dataloader = build_dataloader(cfg, \"train\", dm, AgentDataset, rasterizer)\n",
    "eval_dataloader = build_dataloader(cfg, \"val\", dm, AgentDataset, rasterizer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==== INIT MODEL\n",
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = build_model(cfg).to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "# reduction=\"none\" keeps the per-element losses; forward() reduces them with .mean().\n",
    "criterion = nn.MSELoss(reduction=\"none\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==== TRAIN LOOP\n",
    "# Training mode and gradient tracking do not change between steps, so enable them once up front.\n",
    "model.train()\n",
    "torch.set_grad_enabled(True)\n",
    "\n",
    "tr_it = iter(train_dataloader)\n",
    "progress_bar = tqdm(range(cfg[\"train_params\"][\"max_num_steps\"]))\n",
    "losses_train = []\n",
    "for _ in progress_bar:\n",
    "    # The step budget can exceed one pass over the dataloader: restart it when exhausted.\n",
    "    try:\n",
    "        data = next(tr_it)\n",
    "    except StopIteration:\n",
    "        tr_it = iter(train_dataloader)\n",
    "        data = next(tr_it)\n",
    "\n",
    "    loss, _ = forward(data, model, device, criterion)\n",
    "\n",
    "    # Backward pass\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Read the scalar once and reuse it for both the history and the progress bar.\n",
    "    loss_value = loss.item()\n",
    "    losses_train.append(loss_value)\n",
    "    progress_bar.set_description(f\"loss: {loss_value} loss(avg): {np.mean(losses_train)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation\n",
    "We can now run inference and store predicted and annotated trajectories. \n",
    "\n",
    "In this example, we run it on a single scene from the eval dataset because of computational constraints. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==== EVAL LOOP\n",
    "# Inference only: switch to eval mode and disable gradient tracking.\n",
    "model.eval()\n",
    "torch.set_grad_enabled(False)\n",
    "losses_eval = []\n",
    "\n",
    "# store information for evaluation\n",
    "future_coords_offsets_pd = []\n",
    "future_coords_offsets_gt = []\n",
    "\n",
    "timestamps = []\n",
    "agent_ids = []\n",
    "progress_bar = tqdm(eval_dataloader)\n",
    "for data in progress_bar:\n",
    "    loss, outputs = forward(data, model, device, criterion)\n",
    "    # Read the scalar once and reuse it for both the history and the progress bar.\n",
    "    loss_value = loss.item()\n",
    "    losses_eval.append(loss_value)\n",
    "    progress_bar.set_description(f\"Running EVAL, loss: {loss_value} loss(avg): {np.mean(losses_eval)}\")\n",
    "\n",
    "    # Reshape the flat per-sample vectors back into (x, y) pairs before storing them.\n",
    "    future_coords_offsets_pd.append(outputs.reshape(len(outputs), -1, 2).cpu().numpy().copy())\n",
    "    future_coords_offsets_gt.append(data[\"target_positions\"].reshape(len(outputs), -1, 2).cpu().numpy().copy())\n",
    "\n",
    "    timestamps.append(data[\"timestamp\"].numpy().copy())\n",
    "    agent_ids.append(data[\"track_id\"].numpy().copy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save results and perform evaluation\n",
    "After the model has predicted trajectories for our evaluation set, we can save them in a `csv` file. To simulate a complete evaluation session we can also save the ground truth in another `csv` and get the score.\n",
    "\n",
    "We will get `future_num_frames` values, one per future timestep. Each value is the Displacement Error (the mean of the squared errors, i.e. the L2 loss, between the predicted points and the ground truth points) for that timestep."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==== COMPUTE CSV\n",
    "pred_path = f\"{gettempdir()}/pred.csv\"\n",
    "gt_path = f\"{gettempdir()}/gt.csv\"\n",
    "\n",
    "# values shared by the prediction file and the ground-truth file\n",
    "num_future_frames = cfg[\"model_params\"][\"future_num_frames\"]\n",
    "all_timestamps = np.concatenate(timestamps)\n",
    "all_agent_ids = np.concatenate(agent_ids)\n",
    "\n",
    "# write predictions first, then ground truth, with identical timestamps and agent ids\n",
    "for csv_path, coords_batches in ((pred_path, future_coords_offsets_pd), (gt_path, future_coords_offsets_gt)):\n",
    "    write_coords_as_csv(\n",
    "        csv_path,\n",
    "        future_num_frames=num_future_frames,\n",
    "        future_coords_offsets=np.concatenate(coords_batches),\n",
    "        timestamps=all_timestamps,\n",
    "        agent_ids=all_agent_ids,\n",
    "    )\n",
    "\n",
    "# tabulate the per-step errors returned by the evaluation helper\n",
    "table = PrettyTable(field_names=[\"future step\", \"MSE\"])\n",
    "table.float_format = \".2\"\n",
    "steps = range(1, num_future_frames + 1)\n",
    "errors = compute_mse_error_csv(gt_path, pred_path)\n",
    "for row in zip(steps, errors):\n",
    "    table.add_row(list(row))\n",
    "print(table)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualise Results\n",
    "We can also visualise some results from the ego (AV) point of view. \n",
    "\n",
    "We can use the `get_frame_indices` function to find a frame that contains enough agents to make the visualisation interesting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_agent_dataset = eval_dataloader.dataset.datasets[0].dataset\n",
    "eval_ego_dataset = EgoDataset(cfg, eval_agent_dataset.dataset, rasterizer)\n",
    "\n",
    "# pick the first frame that contains at least `min_agents_to_pick` agents\n",
    "frame_number = -1\n",
    "min_agents_to_pick = 4\n",
    "for idx_frame in range(len(eval_ego_dataset.dataset.frames)):\n",
    "    if len(eval_agent_dataset.get_frame_indices(idx_frame)) >= min_agents_to_pick:\n",
    "        frame_number = idx_frame\n",
    "        break\n",
    "if frame_number == -1:\n",
    "    raise ValueError(f\"can't find a frame with at least {min_agents_to_pick} agents in it\")\n",
    "\n",
    "# inference only: make this cell safe to run on its own\n",
    "model.eval()\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# get AV point-of-view frame\n",
    "data_ego = eval_ego_dataset[frame_number]\n",
    "# move the first axis last before converting to RGB (presumably channels-first -> channels-last - confirm)\n",
    "im_ego = rasterizer.to_rgb(data_ego[\"image\"].transpose(1, 2, 0))\n",
    "\n",
    "agent_indices = eval_agent_dataset.get_frame_indices(frame_number)\n",
    "predicted_positions = []\n",
    "target_positions = []\n",
    "\n",
    "for v_index in agent_indices:\n",
    "    data_agent = eval_agent_dataset[v_index]\n",
    "\n",
    "    # add a batch dimension of 1, run the model, then reshape the flat output into (x, y) pairs\n",
    "    out_net = model(torch.from_numpy(data_agent[\"image\"]).unsqueeze(0).to(device))\n",
    "    out_pos = out_net[0].reshape(-1, 2).detach().cpu().numpy()\n",
    "\n",
    "    # store absolute world coordinates\n",
    "    predicted_positions.append(out_pos + data_agent[\"centroid\"][:2])\n",
    "    target_positions.append(data_agent[\"target_positions\"] + data_agent[\"centroid\"][:2])\n",
    "\n",
    "\n",
    "# convert coordinates to AV point-of-view so we can draw them\n",
    "predicted_positions = transform_points(np.concatenate(predicted_positions), data_ego[\"world_to_image\"])\n",
    "target_positions = transform_points(np.concatenate(target_positions), data_ego[\"world_to_image\"])\n",
    "\n",
    "# draw_trajectory is given a yaw per point; zeros are passed for every point here\n",
    "yaws = np.zeros((len(predicted_positions), 1))\n",
    "draw_trajectory(im_ego, predicted_positions, yaws, PREDICTED_POINTS_COLOR)\n",
    "draw_trajectory(im_ego, target_positions, yaws, TARGET_POINTS_COLOR)\n",
    "\n",
    "# reverse the row order (vertical flip) for display\n",
    "plt.imshow(im_ego[::-1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
